{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/going_modular/05_pytorch_going_modular_script_mode.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "# 05. Going Modular: Part 2 (script mode)\n",
    "\n",
    "This notebook is part 2/2 of section [05. Going Modular](https://www.learnpytorch.io/05_pytorch_going_modular/).\n",
    "\n",
    "For reference, the two parts are:\n",
    "1. [**05. Going Modular: Part 1 (cell mode)**](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/going_modular/05_pytorch_going_modular_cell_mode.ipynb) - this notebook is run as a traditional Jupyter Notebook/Google Colab notebook and is a condensed version of [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/).\n",
    "2. [**05. Going Modular: Part 2 (script mode)**](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/going_modular/05_pytorch_going_modular_script_mode.ipynb) - this notebook is the same as number 1 but with added functionality to turn each of the major sections into Python scripts such as `data_setup.py` and `train.py`.\n",
    "\n",
    "Why two parts?\n",
    "\n",
    "Because sometimes the best way to learn something is to see how it *differs* from something else.\n",
    "\n",
    "If you run each notebook side-by-side you'll see how they differ and that's where the key learnings are."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "## What is script mode?\n",
    "\n",
    "**Script mode** uses [Jupyter Notebook cell magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html) (special commands) to turn specific cells into Python scripts.\n",
    "\n",
    "For example, if you run the following code in a cell, you'll create a Python file called `hello_world.py`:\n",
    "\n",
    "```\n",
    "%%writefile hello_world.py\n",
    "print(\"hello world, machine learning is fun!\")\n",
    "```\n",
    "\n",
    "You could then run this Python file on the command line with:\n",
    "\n",
    "```\n",
    "python hello_world.py\n",
    "\n",
    ">>> hello world, machine learning is fun!\n",
    "```\n",
    "\n",
    "The main cell magic we're interested in using is `%%writefile`.\n",
    "\n",
    "Putting `%%writefile filename` at the top of a cell in Jupyter or Google Colab will write the contents of that cell to a specified `filename`.\n",
    "\n",
    "> **Question:** Do I have to create Python files like this? Can't I just start directly with a Python file and skip using a Google Colab notebook?\n",
    ">\n",
    "> **Answer:** Yes. This is only *one* way of creating Python scripts. If you know the kind of script you'd like to write, you could start writing it straight away. But since using Jupyter/Google Colab notebooks is a popular way of starting off data science and machine learning projects, knowing about the `%%writefile` magic command is a handy tip. "
   ]
  },
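  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the mechanics concrete, here is a minimal plain-Python sketch of roughly what running the `%%writefile hello_world.py` cell and then `python hello_world.py` amounts to. It is an illustration only, not how IPython implements the magic, and the temporary directory is an assumption added so the example cleans up after itself.\n",
    "\n",
    "```python\n",
    "import subprocess\n",
    "import sys\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "# The text that would sit below the `%%writefile hello_world.py` line\n",
    "cell_contents = 'print(\"hello world, machine learning is fun!\")\\n'\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    script_path = Path(tmp_dir) / \"hello_world.py\"\n",
    "\n",
    "    # Step 1: save the cell contents to a file\n",
    "    script_path.write_text(cell_contents)\n",
    "\n",
    "    # Step 2: run the file with the current Python interpreter\n",
    "    result = subprocess.run(\n",
    "        [sys.executable, str(script_path)],\n",
    "        capture_output=True,\n",
    "        text=True,\n",
    "        check=True,\n",
    "    )\n",
    "\n",
    "script_output = result.stdout.strip()\n",
    "print(script_output)\n",
    "```\n",
    "\n",
    "Unlike this sketch, `%%writefile` leaves the file on disk next to the notebook, which is what we want when building up a folder of reusable scripts."
   ]
  },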
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "## What has script mode got to do with PyTorch?\n",
    "\n",
    "If you've written some useful code in a Jupyter Notebook or Google Colab notebook, chances are you'll want to use that code again.\n",
    "\n",
    "And turning your useful cells into Python scripts (`.py` files) means you can use specific pieces of your code in other projects.\n",
    "\n",
    "This practice is not PyTorch specific.\n",
    "\n",
    "But it's how you'll see many different online PyTorch repositories structured."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "### PyTorch in the wild\n",
    "\n",
    "For example, if you find a PyTorch project on GitHub, it may be structured in the following way:\n",
    "\n",
    "```\n",
    "pytorch_project/\n",
    "├── pytorch_project/\n",
    "│   ├── data_setup.py\n",
    "│   ├── engine.py\n",
    "│   ├── model.py\n",
    "│   ├── train.py\n",
    "│   └── utils.py\n",
    "├── models/\n",
    "│   ├── model_1.pth\n",
    "│   └── model_2.pth\n",
    "└── data/\n",
    "    ├── data_folder_1/\n",
    "    └── data_folder_2/\n",
    "```\n",
    "\n",
    "Here, the top level directory is called `pytorch_project` but you could call it whatever you want.\n",
    "\n",
    "Inside there's another directory called `pytorch_project` which contains several `.py` files. Their purposes may be:\n",
    "* `data_setup.py` - a file to prepare data (and download data if needed).\n",
    "* `engine.py` - a file containing various training functions.\n",
    "* `model_builder.py` or `model.py` - a file to create a PyTorch model.\n",
    "* `train.py` - a file to leverage all other files and train a target PyTorch model.\n",
    "* `utils.py` - a file dedicated to helpful utility functions.\n",
    "\n",
    "The `models` and `data` directories could hold PyTorch models and data files respectively. Because model and data files are often large, you're unlikely to find the *full* versions of these on GitHub; they're shown above mainly for demonstration purposes.\n",
    "\n",
    "> **Note:** There are many different ways to structure a Python project and subsequently a PyTorch project. This isn't a guide on *how* to structure your projects, only an example of how you *might* come across PyTorch projects in the wild. For more on structuring Python projects, see Real Python's [*Python Application Layouts: A Reference*](https://realpython.com/python-application-layouts/) guide."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "## What's the difference between this notebook (Part 2) and the cell mode notebook (Part 1)?\n",
    "\n",
    "This notebook, 05 Going Modular: Part 2 (script mode), creates Python scripts out of the cells created in part 1.\n",
    "\n",
    "Running this notebook end-to-end will result in having a directory structure very similar to the `pytorch_project` structure above.\n",
    "\n",
    "You'll notice each section in Part 2 (script mode) has an extra subsection (e.g. 2.1, 3.1, 4.1) for turning cell code into script code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "## What we're going to cover\n",
    "\n",
    "By the end of this notebook you should finish with a directory structure of: \n",
    "\n",
    "```\n",
    "going_modular/\n",
    "├── going_modular/\n",
    "│   ├── data_setup.py\n",
    "│   ├── engine.py\n",
    "│   ├── model_builder.py\n",
    "│   ├── train.py\n",
    "│   └── utils.py\n",
    "├── models/\n",
    "│   ├── 05_going_modular_cell_mode_tinyvgg_model.pth\n",
    "│   └── 05_going_modular_script_mode_tinyvgg_model.pth\n",
    "└── data/\n",
    "    └── pizza_steak_sushi/\n",
    "        ├── train/\n",
    "        │   ├── pizza/\n",
    "        │   │   ├── image01.jpeg\n",
    "        │   │   └── ...\n",
    "        │   ├── steak/\n",
    "        │   └── sushi/\n",
    "        └── test/\n",
    "            ├── pizza/\n",
    "            ├── steak/\n",
    "            └── sushi/\n",
    "```\n",
    "\n",
    "Using this directory structure, you should be able to train a model from within a notebook with the command:\n",
    "\n",
    "```\n",
    "!python going_modular/train.py\n",
    "```\n",
    "\n",
    "Or from the command line with:\n",
    "\n",
    "```\n",
    "python going_modular/train.py\n",
    "```\n",
    "\n",
    "In essence, we will have turned our helpful notebook code into **reusable modular code**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M5tsObEa17-",
    "tags": []
   },
   "source": [
    "## Where can you get help?\n",
    "\n",
    "You can find the book version of this section [05. PyTorch Going Modular on learnpytorch.io](https://www.learnpytorch.io/05_pytorch_going_modular/).\n",
    "\n",
    "The rest of the materials for this course [are available on GitHub](https://github.com/mrdbourke/pytorch-deep-learning).\n",
    "\n",
    "If you run into trouble, you can ask a question on the course [GitHub Discussions page](https://github.com/mrdbourke/pytorch-deep-learning/discussions).\n",
    "\n",
    "And of course, there's the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) and [PyTorch developer forums](https://discuss.pytorch.org/), a very helpful place for all things PyTorch. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OW_3D_fOap2I"
   },
   "source": [
    "## 0. Creating a folder for storing Python scripts\n",
    "\n",
    "Since we're going to be creating Python scripts out of our most useful code cells, let's create a folder for storing those scripts.\n",
    "\n",
    "We'll call the folder `going_modular` and create it using Python's [`os.makedirs()`](https://docs.python.org/3/library/os.html) function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "K29qwItveNm9"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.makedirs(\"going_modular\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1XgXFr97bCsI"
   },
   "source": [
    "## 1. Get data\n",
    "\n",
    "We're going to start by downloading the same data we used in [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/#1-get-data), the `pizza_steak_sushi` dataset with images of pizza, steak and sushi."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wl_jfPzVbGDS",
    "outputId": "51a3e5e8-991c-4309-b1ba-b24d32bbdb61"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/pizza_steak_sushi directory exists.\n",
      "Downloading pizza, steak, sushi data...\n",
      "Unzipping pizza, steak, sushi data...\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import zipfile\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import requests\n",
    "\n",
    "# Setup path to data folder\n",
    "data_path = Path(\"data/\")\n",
    "image_path = data_path / \"pizza_steak_sushi\"\n",
    "\n",
    "# Create the image folder if it doesn't exist (the data is downloaded below either way)\n",
    "if image_path.is_dir():\n",
    "    print(f\"{image_path} directory exists.\")\n",
    "else:\n",
    "    print(f\"Did not find {image_path} directory, creating one...\")\n",
    "    image_path.mkdir(parents=True, exist_ok=True)\n",
    "    \n",
    "# Download pizza, steak, sushi data\n",
    "with open(data_path / \"pizza_steak_sushi.zip\", \"wb\") as f:\n",
    "    request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\")\n",
    "    print(\"Downloading pizza, steak, sushi data...\")\n",
    "    f.write(request.content)\n",
    "\n",
    "# Unzip pizza, steak, sushi data\n",
    "with zipfile.ZipFile(data_path / \"pizza_steak_sushi.zip\", \"r\") as zip_ref:\n",
    "    print(\"Unzipping pizza, steak, sushi data...\") \n",
    "    zip_ref.extractall(image_path)\n",
    "\n",
    "# Remove zip file\n",
    "os.remove(data_path / \"pizza_steak_sushi.zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8qYMdiBMbde0",
    "outputId": "a5093c72-455f-471c-9e95-5b391eea9ebf"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(PosixPath('data/pizza_steak_sushi/train'),\n",
       " PosixPath('data/pizza_steak_sushi/test'))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Setup train and testing paths\n",
    "train_dir = image_path / \"train\"\n",
    "test_dir = image_path / \"test\"\n",
    "\n",
    "train_dir, test_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eYL6NynybU9Z"
   },
   "source": [
    "## 2. Create Datasets and DataLoaders\n",
    "\n",
    "Let's turn our data into PyTorch `Dataset`s and `DataLoader`s and find out a few useful attributes from them such as `classes` and their lengths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DaOJOd-QbOQk",
    "outputId": "dc7168b3-8bee-40c4-867e-58aa058175b5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train data:\n",
      "Dataset ImageFolder\n",
      "    Number of datapoints: 225\n",
      "    Root location: data/pizza_steak_sushi/train\n",
      "    StandardTransform\n",
      "Transform: Compose(\n",
      "               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=None)\n",
      "               ToTensor()\n",
      "           )\n",
      "Test data:\n",
      "Dataset ImageFolder\n",
      "    Number of datapoints: 75\n",
      "    Root location: data/pizza_steak_sushi/test\n",
      "    StandardTransform\n",
      "Transform: Compose(\n",
      "               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=None)\n",
      "               ToTensor()\n",
      "           )\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Create simple transform\n",
    "data_transform = transforms.Compose([ \n",
    "    transforms.Resize((64, 64)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Use ImageFolder to create dataset(s)\n",
    "train_data = datasets.ImageFolder(root=train_dir, # target folder of images\n",
    "                                  transform=data_transform, # transforms to perform on data (images)\n",
    "                                  target_transform=None) # transforms to perform on labels (if necessary)\n",
    "\n",
    "test_data = datasets.ImageFolder(root=test_dir, \n",
    "                                 transform=data_transform)\n",
    "\n",
    "print(f\"Train data:\\n{train_data}\\nTest data:\\n{test_data}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eEln0VFmbfyY",
    "outputId": "c514574e-e3ab-4689-a5ee-6a2aef6a11d2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['pizza', 'steak', 'sushi']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get class names as a list\n",
    "class_names = train_data.classes\n",
    "class_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "IMGTFVr8bgwE",
    "outputId": "4268d11f-8899-447a-baba-7e91ade01f13"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pizza': 0, 'steak': 1, 'sushi': 2}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Can also get class names as a dict\n",
    "class_dict = train_data.class_to_idx\n",
    "class_dict"
   ]
  },
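  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Where do these indices come from? `ImageFolder` takes the class names from the subfolder names of the root directory, sorts them, and numbers them from 0. The sketch below mimics that behaviour in plain Python on a throwaway directory. `sketch_find_classes` is a hypothetical helper written for this illustration, not torchvision's own code.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "def sketch_find_classes(directory):\n",
    "    \"\"\"Return (class_names, class_to_idx) built from the subfolder names of `directory`.\"\"\"\n",
    "    names = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())\n",
    "    name_to_idx = {name: idx for idx, name in enumerate(names)}\n",
    "    return names, name_to_idx\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    # Create the folders in a scrambled order to show that sorting decides the indices\n",
    "    for name in [\"sushi\", \"pizza\", \"steak\"]:\n",
    "        os.makedirs(os.path.join(tmp_dir, name))\n",
    "    demo_classes, demo_class_to_idx = sketch_find_classes(tmp_dir)\n",
    "\n",
    "print(demo_classes)\n",
    "print(demo_class_to_idx)\n",
    "```\n",
    "\n",
    "Because the order depends only on the folder names, the train and test datasets end up with the same label mapping as long as they contain the same class folders."
   ]
  },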
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6a8IW-IYbhGX",
    "outputId": "87bc60ac-3be3-499f-e0a8-aed5964a3075"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(225, 75)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the lengths\n",
    "len(train_data), len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f1UxnvvmblUj",
    "outputId": "1dd0c011-912e-498b-ea34-ee4bf05c7b87"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<torch.utils.data.dataloader.DataLoader at 0x7fca2e344760>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x7fca2e3445b0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn train and test Datasets into DataLoaders\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(dataset=train_data, \n",
    "                              batch_size=1, # how many samples per batch?\n",
    "                              num_workers=1, # how many subprocesses to use for data loading? (higher = more)\n",
    "                              shuffle=True) # shuffle the data?\n",
    "\n",
    "test_dataloader = DataLoader(dataset=test_data, \n",
    "                             batch_size=1, \n",
    "                             num_workers=1, \n",
    "                             shuffle=False) # don't usually need to shuffle testing data\n",
    "\n",
    "train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JJKRqXNGbnI5",
    "outputId": "2d5320fc-ac72-4ef8-94f8-0ef283fc3a30"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image shape: torch.Size([1, 3, 64, 64]) -> [batch_size, color_channels, height, width]\n",
      "Label shape: torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "# Check out single image size/shape\n",
    "img, label = next(iter(train_dataloader))\n",
    "\n",
    "# Batch size will now be 1, try changing the batch_size parameter above and see what happens\n",
    "print(f\"Image shape: {img.shape} -> [batch_size, color_channels, height, width]\")\n",
    "print(f\"Label shape: {label.shape}\")"
   ]
  },
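  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Changing `batch_size` changes how many batches a `DataLoader` yields per epoch. The arithmetic can be sketched in plain Python, assuming the default `drop_last=False` so a smaller final batch is kept. `batch_layout` is a hypothetical helper for illustration only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def batch_layout(num_samples, batch_size):\n",
    "    \"\"\"Return (number of batches, size of the final batch) when no samples are dropped.\"\"\"\n",
    "    num_batches = math.ceil(num_samples / batch_size)\n",
    "    last_batch_size = num_samples - (num_batches - 1) * batch_size\n",
    "    return num_batches, last_batch_size\n",
    "\n",
    "# 225 is the number of training images reported above\n",
    "for demo_batch_size in [1, 32]:\n",
    "    num_batches, last_batch_size = batch_layout(num_samples=225, batch_size=demo_batch_size)\n",
    "    print(f\"batch_size={demo_batch_size}: {num_batches} batches, last batch has {last_batch_size} sample(s)\")\n",
    "```\n",
    "\n",
    "With `batch_size=32`, the first dimension of the image batch would be 32 for every batch except the last one."
   ]
  },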
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RjINjcyGdwL9"
   },
   "source": [
    "### 2.1 Create Datasets and DataLoaders (script mode)\n",
    "\n",
    "Rather than rewriting all of the code above every time we want to load data, we can turn it into a script called `data_setup.py`.\n",
    "\n",
    "Let's capture all of the above functionality into a function called `create_dataloaders()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CmjK-b04du_-",
    "outputId": "07d3450d-bc24-4aa4-a6d6-57d60bf11416"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting going_modular/data_setup.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile going_modular/data_setup.py\n",
    "\"\"\"\n",
    "Contains functionality for creating PyTorch DataLoaders for \n",
    "image classification data.\n",
    "\"\"\"\n",
    "import os\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "NUM_WORKERS = os.cpu_count()\n",
    "\n",
    "def create_dataloaders(\n",
    "    train_dir: str, \n",
    "    test_dir: str, \n",
    "    transform: transforms.Compose, \n",
    "    batch_size: int, \n",
    "    num_workers: int=NUM_WORKERS\n",
    "):\n",
    "  \"\"\"Creates training and testing DataLoaders.\n",
    "\n",
    "  Takes in a training directory and testing directory path and turns\n",
    "  them into PyTorch Datasets and then into PyTorch DataLoaders.\n",
    "\n",
    "  Args:\n",
    "    train_dir: Path to training directory.\n",
    "    test_dir: Path to testing directory.\n",
    "    transform: torchvision transforms to perform on training and testing data.\n",
    "    batch_size: Number of samples per batch in each of the DataLoaders.\n",
    "    num_workers: An integer for number of workers per DataLoader.\n",
    "\n",
    "  Returns:\n",
    "    A tuple of (train_dataloader, test_dataloader, class_names),\n",
    "    where class_names is a list of the target classes.\n",
    "\n",
    "  Example usage:\n",
    "    train_dataloader, test_dataloader, class_names = \\\n",
    "      create_dataloaders(train_dir=path/to/train_dir,\n",
    "                         test_dir=path/to/test_dir,\n",
    "                         transform=some_transform,\n",
    "                         batch_size=32,\n",
    "                         num_workers=4)\n",
    "  \"\"\"\n",
    "  # Use ImageFolder to create dataset(s)\n",
    "  train_data = datasets.ImageFolder(train_dir, transform=transform)\n",
    "  test_data = datasets.ImageFolder(test_dir, transform=transform)\n",
    "\n",
    "  # Get class names\n",
    "  class_names = train_data.classes\n",
    "\n",
    "  # Turn images into data loaders\n",
    "  train_dataloader = DataLoader(\n",
    "      train_data,\n",
    "      batch_size=batch_size,\n",
    "      shuffle=True,\n",
    "      num_workers=num_workers,\n",
    "      pin_memory=True,\n",
    "  )\n",
    "  test_dataloader = DataLoader(\n",
    "      test_data,\n",
    "      batch_size=batch_size,\n",
    "      shuffle=False,\n",
    "      num_workers=num_workers,\n",
    "      pin_memory=True,\n",
    "  )\n",
    "\n",
    "  return train_dataloader, test_dataloader, class_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gAQLp1dCbrY0"
   },
   "source": [
    "## 3. Making a model (TinyVGG) \n",
    "\n",
    "We're going to use the same model we used in notebook 04: TinyVGG from the CNN Explainer website.\n",
    "\n",
    "The only change here from notebook 04 is that a docstring has been added using [Google's Style Guide for Python](https://google.github.io/styleguide/pyguide.html#384-classes). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "Z6i2xWcwbvom"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from torch import nn \n",
    "\n",
    "class TinyVGG(nn.Module):\n",
    "    \"\"\"Creates the TinyVGG architecture.\n",
    "\n",
    "    Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.\n",
    "    See the original architecture here: https://poloclub.github.io/cnn-explainer/\n",
    "\n",
    "    Args:\n",
    "        input_shape: An integer indicating number of input channels.\n",
    "        hidden_units: An integer indicating number of hidden units between layers.\n",
    "        output_shape: An integer indicating number of output units.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:\n",
    "        super().__init__()\n",
    "        self.conv_block_1 = nn.Sequential(\n",
    "          nn.Conv2d(in_channels=input_shape, \n",
    "                    out_channels=hidden_units, \n",
    "                    kernel_size=3, \n",
    "                    stride=1, \n",
    "                    padding=0),  \n",
    "          nn.ReLU(),\n",
    "          nn.Conv2d(in_channels=hidden_units, \n",
    "                    out_channels=hidden_units,\n",
    "                    kernel_size=3,\n",
    "                    stride=1,\n",
    "                    padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.MaxPool2d(kernel_size=2,\n",
    "                        stride=2)\n",
    "        )\n",
    "        self.conv_block_2 = nn.Sequential(\n",
    "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.MaxPool2d(2)\n",
    "        )\n",
    "        self.classifier = nn.Sequential(\n",
    "          nn.Flatten(),\n",
    "          # in_features is hidden_units*13*13 because the conv and max pool layers\n",
    "          # shrink each 64x64 input image down to 13x13 feature maps.\n",
    "          nn.Linear(in_features=hidden_units*13*13,\n",
    "                    out_features=output_shape)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x: torch.Tensor):\n",
    "        x = self.conv_block_1(x)\n",
    "        x = self.conv_block_2(x)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "        # return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusion"
   ]
  },
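  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `13*13` in the classifier's `in_features` can be checked by hand. The sketch below applies the standard output-size formula for convolution and max pooling layers to a 64x64 input, assuming a dilation of 1 (the PyTorch default). `conv_out` and `pool_out` are hypothetical helpers for this illustration.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel_size=3, stride=1, padding=0):\n",
    "    \"\"\"Output height/width of a Conv2d layer (dilation of 1 assumed).\"\"\"\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "def pool_out(size, kernel_size=2, stride=2):\n",
    "    \"\"\"Output height/width of a MaxPool2d layer (no padding, dilation of 1 assumed).\"\"\"\n",
    "    return (size - kernel_size) // stride + 1\n",
    "\n",
    "feature_map_size = 64  # input images are resized to 64x64\n",
    "for _ in range(2):  # conv_block_1, then conv_block_2\n",
    "    feature_map_size = conv_out(feature_map_size)  # first Conv2d in the block\n",
    "    feature_map_size = conv_out(feature_map_size)  # second Conv2d in the block\n",
    "    feature_map_size = pool_out(feature_map_size)  # MaxPool2d at the end of the block\n",
    "\n",
    "demo_hidden_units = 10\n",
    "flattened_features = demo_hidden_units * feature_map_size * feature_map_size\n",
    "print(feature_map_size, flattened_features)\n",
    "```\n",
    "\n",
    "The sizes go 64 -> 62 -> 60 -> 30 -> 28 -> 26 -> 13, so with 10 hidden units the flattened vector has 10 * 13 * 13 = 1690 values. A different input image size would need a different `in_features` value."
   ]
  },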
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ROZsju1d9Cbu"
   },
   "source": [
    "Now let's create an instance of `TinyVGG` and put it on the target device.\n",
    "\n",
    "> **Note:** If you're using Google Colab and you'd like to use a GPU (recommended), you can turn one on by going to Runtime -> Change runtime type -> Hardware accelerator -> GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tXOOAEHHc6-L",
    "outputId": "3207fda6-896b-4cb5-9f08-f1e46879bf75"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TinyVGG(\n",
       "  (conv_block_1): Sequential(\n",
       "    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv_block_2): Sequential(\n",
       "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Flatten(start_dim=1, end_dim=-1)\n",
       "    (1): Linear(in_features=1690, out_features=3, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Instantiate an instance of the model\n",
    "torch.manual_seed(42)\n",
    "model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) \n",
    "                  hidden_units=10, \n",
    "                  output_shape=len(train_data.classes)).to(device)\n",
    "model_0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "C1CwvPiMcEfZ"
   },
   "source": [
    "Let's check out our model by doing a dummy forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "__p5G2ArcFko",
    "outputId": "341f2b73-a0fa-4fdb-f458-142510dcfcca"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Single image shape: torch.Size([1, 3, 64, 64])\n",
      "\n",
      "Output logits:\n",
      "tensor([[ 0.0208, -0.0019,  0.0095]], device='cuda:0')\n",
      "\n",
      "Output prediction probabilities:\n",
      "tensor([[0.3371, 0.3295, 0.3333]], device='cuda:0')\n",
      "\n",
      "Output prediction label:\n",
      "tensor([0], device='cuda:0')\n",
      "\n",
      "Actual label:\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# 1. Get a batch of images and labels from the DataLoader\n",
    "img_batch, label_batch = next(iter(train_dataloader))\n",
    "\n",
    "# 2. Get a single image from the batch and unsqueeze the image so its shape fits the model\n",
    "img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]\n",
    "print(f\"Single image shape: {img_single.shape}\\n\")\n",
    "\n",
    "# 3. Perform a forward pass on a single image\n",
    "model_0.eval()\n",
    "with torch.inference_mode():\n",
    "    pred = model_0(img_single.to(device))\n",
    "    \n",
    "# 4. Print out what's happening and convert model logits -> pred probs -> pred label\n",
    "print(f\"Output logits:\\n{pred}\\n\")\n",
    "print(f\"Output prediction probabilities:\\n{torch.softmax(pred, dim=1)}\\n\")\n",
    "print(f\"Output prediction label:\\n{torch.argmax(torch.softmax(pred, dim=1), dim=1)}\\n\")\n",
    "print(f\"Actual label:\\n{label_single}\")"
   ]
  },
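  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `torch.softmax` and `torch.argmax` are doing in the last step, here is a plain-Python sketch of the same logits -> prediction probabilities -> prediction label conversion. The logits are copied from the printed output above, so they are already rounded and the probabilities only match to about three decimal places.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softmax(logits):\n",
    "    \"\"\"Convert a list of raw scores into probabilities that sum to 1.\"\"\"\n",
    "    max_logit = max(logits)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(value - max_logit) for value in logits]\n",
    "    total = sum(exps)\n",
    "    return [value / total for value in exps]\n",
    "\n",
    "# Logits copied (already rounded) from the printed output above\n",
    "demo_logits = [0.0208, -0.0019, 0.0095]\n",
    "demo_probs = softmax(demo_logits)\n",
    "demo_label = demo_probs.index(max(demo_probs))  # plain-Python argmax\n",
    "\n",
    "print([round(p, 4) for p in demo_probs])\n",
    "print(demo_label)\n",
    "```\n",
    "\n",
    "Softmax keeps the ordering of the logits, so taking the argmax of the raw logits gives the same label. The probabilities are all close to 1/3 because the model is untrained."
   ]
  },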
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wSaDm_W7bc3y"
   },
   "source": [
    "### 3.1 Making a model (TinyVGG) (script mode)\n",
    "\n",
    "Over the past few notebooks (notebook 03 and notebook 04), we've built the TinyVGG model a few times.\n",
    "\n",
    "So it makes sense to put the model into its own file so we can reuse it again and again.\n",
    "\n",
    "Let's put our `TinyVGG()` model class into a script called `model_builder.py` with the line `%%writefile going_modular/model_builder.py`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "oEH3bklL86hF",
    "outputId": "b3f7e983-e51c-4e35-a991-df5ff99d63e6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting going_modular/model_builder.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile going_modular/model_builder.py\n",
    "\"\"\"\n",
    "Contains PyTorch model code to instantiate a TinyVGG model.\n",
    "\"\"\"\n",
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "class TinyVGG(nn.Module):\n",
    "    \"\"\"Creates the TinyVGG architecture.\n",
    "\n",
    "    Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.\n",
    "    See the original architecture here: https://poloclub.github.io/cnn-explainer/\n",
    "\n",
    "    Args:\n",
    "        input_shape: An integer indicating number of input channels.\n",
    "        hidden_units: An integer indicating number of hidden units between layers.\n",
    "        output_shape: An integer indicating number of output units.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:\n",
    "        super().__init__()\n",
    "        self.conv_block_1 = nn.Sequential(\n",
    "          nn.Conv2d(in_channels=input_shape, \n",
    "                    out_channels=hidden_units, \n",
    "                    kernel_size=3, \n",
    "                    stride=1, \n",
    "                    padding=0),  \n",
    "          nn.ReLU(),\n",
    "          nn.Conv2d(in_channels=hidden_units, \n",
    "                    out_channels=hidden_units,\n",
    "                    kernel_size=3,\n",
    "                    stride=1,\n",
    "                    padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.MaxPool2d(kernel_size=2,\n",
    "                        stride=2)\n",
    "        )\n",
    "        self.conv_block_2 = nn.Sequential(\n",
    "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
    "          nn.ReLU(),\n",
    "          nn.MaxPool2d(2)\n",
    "        )\n",
    "        self.classifier = nn.Sequential(\n",
    "          nn.Flatten(),\n",
    "          # in_features is hidden_units*13*13 because the conv and max pool layers\n",
    "          # shrink each 64x64 input image down to 13x13 feature maps.\n",
    "          nn.Linear(in_features=hidden_units*13*13,\n",
    "                    out_features=output_shape)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x: torch.Tensor):\n",
    "        x = self.conv_block_1(x)\n",
    "        x = self.conv_block_2(x)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "        # return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3-WijBWM9K2c"
   },
   "source": [
    "Create an instance of `TinyVGG` (from the script)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "F__hbQwZ9A56",
    "outputId": "4e8abd04-cbc4-4aa8-b71c-3a8ff8617407"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TinyVGG(\n",
       "  (conv_block_1): Sequential(\n",
       "    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv_block_2): Sequential(\n",
       "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Flatten(start_dim=1, end_dim=-1)\n",
       "    (1): Linear(in_features=1690, out_features=3, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "from going_modular import model_builder\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Instantiate an instance of the model from the \"model_builder.py\" script\n",
    "torch.manual_seed(42)\n",
    "model_1 = model_builder.TinyVGG(input_shape=3, # number of color channels (3 for RGB) \n",
    "                                hidden_units=10, \n",
    "                                output_shape=len(class_names)).to(device)\n",
    "model_1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vBD-yWcp9QVT"
   },
   "source": [
    "Do a dummy forward pass on `model_1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NCQWeeMj9kHE",
    "outputId": "cb33c4ec-4848-423c-aea3-89a131e07cfd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Single image shape: torch.Size([1, 3, 64, 64])\n",
      "\n",
      "Output logits:\n",
      "tensor([[ 0.0208, -0.0019,  0.0095]], device='cuda:0')\n",
      "\n",
      "Output prediction probabilities:\n",
      "tensor([[0.3371, 0.3295, 0.3333]], device='cuda:0')\n",
      "\n",
      "Output prediction label:\n",
      "tensor([0], device='cuda:0')\n",
      "\n",
      "Actual label:\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# 1. Get a batch of images and labels from the DataLoader\n",
    "img_batch, label_batch = next(iter(train_dataloader))\n",
    "\n",
    "# 2. Get a single image from the batch and unsqueeze the image so its shape fits the model\n",
    "img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]\n",
    "print(f\"Single image shape: {img_single.shape}\\n\")\n",
    "\n",
    "# 3. Perform a forward pass on a single image\n",
    "model_1.eval()\n",
    "with torch.inference_mode():\n",
    "    pred = model_1(img_single.to(device))\n",
    "    \n",
    "# 4. Print out what's happening and convert model logits -> pred probs -> pred label\n",
    "print(f\"Output logits:\\n{pred}\\n\")\n",
    "print(f\"Output prediction probabilities:\\n{torch.softmax(pred, dim=1)}\\n\")\n",
    "print(f\"Output prediction label:\\n{torch.argmax(torch.softmax(pred, dim=1), dim=1)}\\n\")\n",
    "print(f\"Actual label:\\n{label_single}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8OB4xIsXcHrp"
   },
   "source": [
    "## 4. Creating `train_step()` and `test_step()` functions and `train()` to combine them  \n",
    "\n",
    "Rather than writing them again, we can reuse the `train_step()` and `test_step()` functions from [notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/#75-create-train-test-loop-functions).\n",
    "\n",
    "The same goes for the `train()` function we created.\n",
    "\n",
    "The only difference here is that these functions have had docstrings added to them in [Google's Python Functions and Methods Style Guide](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods).\n",
    "\n",
    "Let's start by making `train_step()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "t_0zRYZecJZ6"
   },
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "def train_step(model: torch.nn.Module, \n",
    "               dataloader: torch.utils.data.DataLoader, \n",
    "               loss_fn: torch.nn.Module, \n",
    "               optimizer: torch.optim.Optimizer,\n",
    "               device: torch.device) -> Tuple[float, float]:\n",
    "    \"\"\"Trains a PyTorch model for a single epoch.\n",
    "\n",
    "    Turns a target PyTorch model to training mode and then\n",
    "    runs through all of the required training steps (forward\n",
    "    pass, loss calculation, optimizer step).\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be trained.\n",
    "    dataloader: A DataLoader instance for the model to be trained on.\n",
    "    loss_fn: A PyTorch loss function to minimize.\n",
    "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A tuple of training loss and training accuracy metrics.\n",
    "    In the form (train_loss, train_accuracy). For example:\n",
    "\n",
    "    (0.1112, 0.8743)\n",
    "    \"\"\"\n",
    "    # Put model in train mode\n",
    "    model.train()\n",
    "\n",
    "    # Setup train loss and train accuracy values\n",
    "    train_loss, train_acc = 0, 0\n",
    "\n",
    "    # Loop through data loader data batches\n",
    "    for batch, (X, y) in enumerate(dataloader):\n",
    "        # Send data to target device\n",
    "        X, y = X.to(device), y.to(device)\n",
    "\n",
    "        # 1. Forward pass\n",
    "        y_pred = model(X)\n",
    "\n",
    "        # 2. Calculate  and accumulate loss\n",
    "        loss = loss_fn(y_pred, y)\n",
    "        train_loss += loss.item() \n",
    "\n",
    "        # 3. Optimizer zero grad\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # 4. Loss backward\n",
    "        loss.backward()\n",
    "\n",
    "        # 5. Optimizer step\n",
    "        optimizer.step()\n",
    "\n",
    "        # Calculate and accumulate accuracy metric across all batches\n",
    "        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)\n",
    "        train_acc += (y_pred_class == y).sum().item()/len(y_pred)\n",
    "\n",
    "    # Adjust metrics to get average loss and accuracy per batch \n",
    "    train_loss = train_loss / len(dataloader)\n",
    "    train_acc = train_acc / len(dataloader)\n",
    "    return train_loss, train_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1nwAPwNh-f3k"
   },
   "source": [
    "Now we'll do `test_step()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "sr_9AspYcKkW"
   },
   "outputs": [],
   "source": [
    "def test_step(model: torch.nn.Module, \n",
    "              dataloader: torch.utils.data.DataLoader, \n",
    "              loss_fn: torch.nn.Module,\n",
    "              device: torch.device) -> Tuple[float, float]:\n",
    "    \"\"\"Tests a PyTorch model for a single epoch.\n",
    "\n",
    "    Turns a target PyTorch model to \"eval\" mode and then performs\n",
    "    a forward pass on a testing dataset.\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be tested.\n",
    "    dataloader: A DataLoader instance for the model to be tested on.\n",
    "    loss_fn: A PyTorch loss function to calculate loss on the test data.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A tuple of testing loss and testing accuracy metrics.\n",
    "    In the form (test_loss, test_accuracy). For example:\n",
    "\n",
    "    (0.0223, 0.8985)\n",
    "    \"\"\"\n",
    "    # Put model in eval mode\n",
    "    model.eval() \n",
    "\n",
    "    # Setup test loss and test accuracy values\n",
    "    test_loss, test_acc = 0, 0\n",
    "\n",
    "    # Turn on inference context manager\n",
    "    with torch.inference_mode():\n",
    "        # Loop through DataLoader batches\n",
    "        for batch, (X, y) in enumerate(dataloader):\n",
    "            # Send data to target device\n",
    "            X, y = X.to(device), y.to(device)\n",
    "\n",
    "            # 1. Forward pass\n",
    "            test_pred_logits = model(X)\n",
    "\n",
    "            # 2. Calculate and accumulate loss\n",
    "            loss = loss_fn(test_pred_logits, y)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            # Calculate and accumulate accuracy\n",
    "            test_pred_labels = test_pred_logits.argmax(dim=1)\n",
    "            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))\n",
    "\n",
    "    # Adjust metrics to get average loss and accuracy per batch \n",
    "    test_loss = test_loss / len(dataloader)\n",
    "    test_acc = test_acc / len(dataloader)\n",
    "    return test_loss, test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "weqHDd5b-iGz"
   },
   "source": [
    "And we'll combine `train_step()` and `test_step()` into `train()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "-oze8b6icWE7"
   },
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "def train(model: torch.nn.Module, \n",
    "          train_dataloader: torch.utils.data.DataLoader, \n",
    "          test_dataloader: torch.utils.data.DataLoader, \n",
    "          optimizer: torch.optim.Optimizer,\n",
    "          loss_fn: torch.nn.Module,\n",
    "          epochs: int,\n",
    "          device: torch.device) -> Dict[str, List[float]]:\n",
    "    \"\"\"Trains and tests a PyTorch model.\n",
    "\n",
    "    Passes a target PyTorch models through train_step() and test_step()\n",
    "    functions for a number of epochs, training and testing the model\n",
    "    in the same epoch loop.\n",
    "\n",
    "    Calculates, prints and stores evaluation metrics throughout.\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be trained and tested.\n",
    "    train_dataloader: A DataLoader instance for the model to be trained on.\n",
    "    test_dataloader: A DataLoader instance for the model to be tested on.\n",
    "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
    "    loss_fn: A PyTorch loss function to calculate loss on both datasets.\n",
    "    epochs: An integer indicating how many epochs to train for.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A dictionary of training and testing loss as well as training and\n",
    "    testing accuracy metrics. Each metric has a value in a list for \n",
    "    each epoch.\n",
    "    In the form: {train_loss: [...],\n",
    "                  train_acc: [...],\n",
    "                  test_loss: [...],\n",
    "                  test_acc: [...]} \n",
    "    For example if training for epochs=2: \n",
    "                 {train_loss: [2.0616, 1.0537],\n",
    "                  train_acc: [0.3945, 0.3945],\n",
    "                  test_loss: [1.2641, 1.5706],\n",
    "                  test_acc: [0.3400, 0.2973]} \n",
    "    \"\"\"\n",
    "    # Create empty results dictionary\n",
    "    results = {\"train_loss\": [],\n",
    "      \"train_acc\": [],\n",
    "      \"test_loss\": [],\n",
    "      \"test_acc\": []\n",
    "    }\n",
    "\n",
    "    # Loop through training and testing steps for a number of epochs\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        train_loss, train_acc = train_step(model=model,\n",
    "                                          dataloader=train_dataloader,\n",
    "                                          loss_fn=loss_fn,\n",
    "                                          optimizer=optimizer,\n",
    "                                          device=device)\n",
    "        test_loss, test_acc = test_step(model=model,\n",
    "          dataloader=test_dataloader,\n",
    "          loss_fn=loss_fn,\n",
    "          device=device)\n",
    "\n",
    "        # Print out what's happening\n",
    "        print(\n",
    "          f\"Epoch: {epoch+1} | \"\n",
    "          f\"train_loss: {train_loss:.4f} | \"\n",
    "          f\"train_acc: {train_acc:.4f} | \"\n",
    "          f\"test_loss: {test_loss:.4f} | \"\n",
    "          f\"test_acc: {test_acc:.4f}\"\n",
    "        )\n",
    "\n",
    "        # Update results dictionary\n",
    "        results[\"train_loss\"].append(train_loss)\n",
    "        results[\"train_acc\"].append(train_acc)\n",
    "        results[\"test_loss\"].append(test_loss)\n",
    "        results[\"test_acc\"].append(test_acc)\n",
    "\n",
    "    # Return the filled results at the end of the epochs\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tw_nYa4kfQeI"
   },
   "source": [
    "### 4.1 Creating `train_step()` and `test_step()` functions and `train()` to combine them (script mode)   \n",
    "\n",
    "To create a script for `train_step()`, `test_step()` and `train()`, we'll combine their code all into a single cell.\n",
    "\n",
    "We'll then write that cell to a file called `engine.py` because these functions will be the \"engine\" of our training pipeline.\n",
    "\n",
    "We can do so with the magic line `%%writefile going_modular/engine.py`.\n",
    "\n",
    "We'll also make sure to put all the imports we need (`torch`, `typing`, and `tqdm`) at the top of the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "d47bhfY3fVq5",
    "outputId": "41053838-2e28-4702-d833-d8c97eed119b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting going_modular/engine.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile going_modular/engine.py\n",
    "\"\"\"\n",
    "Contains functions for training and testing a PyTorch model.\n",
    "\"\"\"\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import torch\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "def train_step(model: torch.nn.Module, \n",
    "               dataloader: torch.utils.data.DataLoader, \n",
    "               loss_fn: torch.nn.Module, \n",
    "               optimizer: torch.optim.Optimizer,\n",
    "               device: torch.device) -> Tuple[float, float]:\n",
    "    \"\"\"Trains a PyTorch model for a single epoch.\n",
    "\n",
    "    Turns a target PyTorch model to training mode and then\n",
    "    runs through all of the required training steps (forward\n",
    "    pass, loss calculation, optimizer step).\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be trained.\n",
    "    dataloader: A DataLoader instance for the model to be trained on.\n",
    "    loss_fn: A PyTorch loss function to minimize.\n",
    "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A tuple of training loss and training accuracy metrics.\n",
    "    In the form (train_loss, train_accuracy). For example:\n",
    "\n",
    "    (0.1112, 0.8743)\n",
    "    \"\"\"\n",
    "    # Put model in train mode\n",
    "    model.train()\n",
    "\n",
    "    # Setup train loss and train accuracy values\n",
    "    train_loss, train_acc = 0, 0\n",
    "\n",
    "    # Loop through data loader data batches\n",
    "    for batch, (X, y) in enumerate(dataloader):\n",
    "        # Send data to target device\n",
    "        X, y = X.to(device), y.to(device)\n",
    "\n",
    "        # 1. Forward pass\n",
    "        y_pred = model(X)\n",
    "\n",
    "        # 2. Calculate  and accumulate loss\n",
    "        loss = loss_fn(y_pred, y)\n",
    "        train_loss += loss.item() \n",
    "\n",
    "        # 3. Optimizer zero grad\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # 4. Loss backward\n",
    "        loss.backward()\n",
    "\n",
    "        # 5. Optimizer step\n",
    "        optimizer.step()\n",
    "\n",
    "        # Calculate and accumulate accuracy metric across all batches\n",
    "        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)\n",
    "        train_acc += (y_pred_class == y).sum().item()/len(y_pred)\n",
    "\n",
    "    # Adjust metrics to get average loss and accuracy per batch \n",
    "    train_loss = train_loss / len(dataloader)\n",
    "    train_acc = train_acc / len(dataloader)\n",
    "    return train_loss, train_acc\n",
    "\n",
    "def test_step(model: torch.nn.Module, \n",
    "              dataloader: torch.utils.data.DataLoader, \n",
    "              loss_fn: torch.nn.Module,\n",
    "              device: torch.device) -> Tuple[float, float]:\n",
    "    \"\"\"Tests a PyTorch model for a single epoch.\n",
    "\n",
    "    Turns a target PyTorch model to \"eval\" mode and then performs\n",
    "    a forward pass on a testing dataset.\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be tested.\n",
    "    dataloader: A DataLoader instance for the model to be tested on.\n",
    "    loss_fn: A PyTorch loss function to calculate loss on the test data.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A tuple of testing loss and testing accuracy metrics.\n",
    "    In the form (test_loss, test_accuracy). For example:\n",
    "\n",
    "    (0.0223, 0.8985)\n",
    "    \"\"\"\n",
    "    # Put model in eval mode\n",
    "    model.eval() \n",
    "\n",
    "    # Setup test loss and test accuracy values\n",
    "    test_loss, test_acc = 0, 0\n",
    "\n",
    "    # Turn on inference context manager\n",
    "    with torch.inference_mode():\n",
    "        # Loop through DataLoader batches\n",
    "        for batch, (X, y) in enumerate(dataloader):\n",
    "            # Send data to target device\n",
    "            X, y = X.to(device), y.to(device)\n",
    "\n",
    "            # 1. Forward pass\n",
    "            test_pred_logits = model(X)\n",
    "\n",
    "            # 2. Calculate and accumulate loss\n",
    "            loss = loss_fn(test_pred_logits, y)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            # Calculate and accumulate accuracy\n",
    "            test_pred_labels = test_pred_logits.argmax(dim=1)\n",
    "            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))\n",
    "\n",
    "    # Adjust metrics to get average loss and accuracy per batch \n",
    "    test_loss = test_loss / len(dataloader)\n",
    "    test_acc = test_acc / len(dataloader)\n",
    "    return test_loss, test_acc\n",
    "\n",
    "def train(model: torch.nn.Module, \n",
    "          train_dataloader: torch.utils.data.DataLoader, \n",
    "          test_dataloader: torch.utils.data.DataLoader, \n",
    "          optimizer: torch.optim.Optimizer,\n",
    "          loss_fn: torch.nn.Module,\n",
    "          epochs: int,\n",
    "          device: torch.device) -> Dict[str, List[float]]:\n",
    "    \"\"\"Trains and tests a PyTorch model.\n",
    "\n",
    "    Passes a target PyTorch models through train_step() and test_step()\n",
    "    functions for a number of epochs, training and testing the model\n",
    "    in the same epoch loop.\n",
    "\n",
    "    Calculates, prints and stores evaluation metrics throughout.\n",
    "\n",
    "    Args:\n",
    "    model: A PyTorch model to be trained and tested.\n",
    "    train_dataloader: A DataLoader instance for the model to be trained on.\n",
    "    test_dataloader: A DataLoader instance for the model to be tested on.\n",
    "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
    "    loss_fn: A PyTorch loss function to calculate loss on both datasets.\n",
    "    epochs: An integer indicating how many epochs to train for.\n",
    "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
    "\n",
    "    Returns:\n",
    "    A dictionary of training and testing loss as well as training and\n",
    "    testing accuracy metrics. Each metric has a value in a list for \n",
    "    each epoch.\n",
    "    In the form: {train_loss: [...],\n",
    "              train_acc: [...],\n",
    "              test_loss: [...],\n",
    "              test_acc: [...]} \n",
    "    For example if training for epochs=2: \n",
    "             {train_loss: [2.0616, 1.0537],\n",
    "              train_acc: [0.3945, 0.3945],\n",
    "              test_loss: [1.2641, 1.5706],\n",
    "              test_acc: [0.3400, 0.2973]} \n",
    "    \"\"\"\n",
    "    # Create empty results dictionary\n",
    "    results = {\"train_loss\": [],\n",
    "               \"train_acc\": [],\n",
    "               \"test_loss\": [],\n",
    "               \"test_acc\": []\n",
    "    }\n",
    "\n",
    "    # Loop through training and testing steps for a number of epochs\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        train_loss, train_acc = train_step(model=model,\n",
    "                                          dataloader=train_dataloader,\n",
    "                                          loss_fn=loss_fn,\n",
    "                                          optimizer=optimizer,\n",
    "                                          device=device)\n",
    "        test_loss, test_acc = test_step(model=model,\n",
    "          dataloader=test_dataloader,\n",
    "          loss_fn=loss_fn,\n",
    "          device=device)\n",
    "\n",
    "        # Print out what's happening\n",
    "        print(\n",
    "          f\"Epoch: {epoch+1} | \"\n",
    "          f\"train_loss: {train_loss:.4f} | \"\n",
    "          f\"train_acc: {train_acc:.4f} | \"\n",
    "          f\"test_loss: {test_loss:.4f} | \"\n",
    "          f\"test_acc: {test_acc:.4f}\"\n",
    "        )\n",
    "\n",
    "        # Update results dictionary\n",
    "        results[\"train_loss\"].append(train_loss)\n",
    "        results[\"train_acc\"].append(train_acc)\n",
    "        results[\"test_loss\"].append(test_loss)\n",
    "        results[\"test_acc\"].append(test_acc)\n",
    "\n",
    "    # Return the filled results at the end of the epochs\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A2nXumITc9DL"
   },
   "source": [
    "## 5. Creating a function to save the model\n",
    "\n",
    "Let's setup a function to save our model to a directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "lu_vyjq-c933"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "def save_model(model: torch.nn.Module,\n",
    "               target_dir: str,\n",
    "               model_name: str):\n",
    "    \"\"\"Saves a PyTorch model to a target directory.\n",
    "\n",
    "    Args:\n",
    "    model: A target PyTorch model to save.\n",
    "    target_dir: A directory for saving the model to.\n",
    "    model_name: A filename for the saved model. Should include\n",
    "      either \".pth\" or \".pt\" as the file extension.\n",
    "\n",
    "    Example usage:\n",
    "    save_model(model=model_0,\n",
    "               target_dir=\"models\",\n",
    "               model_name=\"05_going_modular_tingvgg_model.pth\")\n",
    "    \"\"\"\n",
    "    # Create target directory\n",
    "    target_dir_path = Path(target_dir)\n",
    "    target_dir_path.mkdir(parents=True,\n",
    "                        exist_ok=True)\n",
    "\n",
    "    # Create model save path\n",
    "    assert model_name.endswith(\".pth\") or model_name.endswith(\".pt\"), \"model_name should end with '.pt' or '.pth'\"\n",
    "    model_save_path = target_dir_path / model_name\n",
    "\n",
    "    # Save the model state_dict()\n",
    "    print(f\"[INFO] Saving model to: {model_save_path}\")\n",
    "    torch.save(obj=model.state_dict(),\n",
    "             f=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F7SnPGOfc_Wc"
   },
   "source": [
    "### 5.1 Creating a function to save the model (script mode)\n",
    "\n",
    "How about we add our `save_model()` function to a script called `utils.py` which is short for \"utilities\".\n",
    "\n",
    "We can do so with the magic line `%%writefile going_modular/utils.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "20LC4xwN_2ZT",
    "outputId": "30852196-8138-4f9e-9f43-578dc6277682"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting going_modular/utils.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile going_modular/utils.py\n",
    "\"\"\"\n",
    "Contains various utility functions for PyTorch model training and saving.\n",
    "\"\"\"\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "\n",
    "def save_model(model: torch.nn.Module,\n",
    "               target_dir: str,\n",
    "               model_name: str):\n",
    "    \"\"\"Saves a PyTorch model to a target directory.\n",
    "\n",
    "    Args:\n",
    "    model: A target PyTorch model to save.\n",
    "    target_dir: A directory for saving the model to.\n",
    "    model_name: A filename for the saved model. Should include\n",
    "      either \".pth\" or \".pt\" as the file extension.\n",
    "\n",
    "    Example usage:\n",
    "    save_model(model=model_0,\n",
    "               target_dir=\"models\",\n",
    "               model_name=\"05_going_modular_tingvgg_model.pth\")\n",
    "    \"\"\"\n",
    "    # Create target directory\n",
    "    target_dir_path = Path(target_dir)\n",
    "    target_dir_path.mkdir(parents=True,\n",
    "                        exist_ok=True)\n",
    "\n",
    "    # Create model save path\n",
    "    assert model_name.endswith(\".pth\") or model_name.endswith(\".pt\"), \"model_name should end with '.pt' or '.pth'\"\n",
    "    model_save_path = target_dir_path / model_name\n",
    "\n",
    "    # Save the model state_dict()\n",
    "    print(f\"[INFO] Saving model to: {model_save_path}\")\n",
    "    torch.save(obj=model.state_dict(),\n",
    "             f=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "20ZO64U2cZY_"
   },
   "source": [
    "## 6. Train, evaluate and save the model\n",
    "\n",
    "Let's leverage the functions we've got above to train, test and save a model to file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 173,
     "referenced_widgets": [
      "84317b054c244ee4b1bed7b2f86660b7",
      "a0a72aed7e2743f8a95c6af3b6e4a457",
      "b517f855551f4c53969677161952a1ac",
      "4d780e75e5d24241b1f6c0203a8bd73e",
      "a0caf8f4760c4935b5b9be86566ee8e0",
      "83cf71b79abd4fc5ae9daaac8f586841",
      "3402f79d3df643efbed25c72aac94642",
      "1e79615e722f456189e07874d13564d2",
      "bac27d87b4d54b6cbaf85dd7b7ecdc64",
      "f9b7fbae53e84883ac281343bb6a9da9",
      "4bc100911d7d4303b6ab7efb59d3420c"
     ]
    },
    "id": "FjeRMiLjccVc",
    "outputId": "2d9005a4-afce-4230-c758-1c6efa01a5e8"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "205f478a42ca42f2ab32113d597dc388",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | train_loss: 1.0956 | train_acc: 0.3867 | test_loss: 1.0630 | test_acc: 0.4133\n",
      "Epoch: 2 | train_loss: 1.0141 | train_acc: 0.5111 | test_loss: 1.0210 | test_acc: 0.4533\n",
      "Epoch: 3 | train_loss: 0.9591 | train_acc: 0.5644 | test_loss: 0.9961 | test_acc: 0.4400\n",
      "Epoch: 4 | train_loss: 0.8994 | train_acc: 0.5778 | test_loss: 0.9986 | test_acc: 0.4533\n",
      "Epoch: 5 | train_loss: 0.8652 | train_acc: 0.6267 | test_loss: 1.0010 | test_acc: 0.5467\n",
      "[INFO] Total training time: 5.461 seconds\n",
      "[INFO] Saving model to: models/05_going_modular_cell_mode_tinyvgg_model.pth\n"
     ]
    }
   ],
   "source": [
    "# Set random seeds\n",
    "torch.manual_seed(42) \n",
    "torch.cuda.manual_seed(42)\n",
    "\n",
    "# Set number of epochs\n",
    "NUM_EPOCHS = 5\n",
    "\n",
    "# Recreate an instance of TinyVGG\n",
    "model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) \n",
    "                  hidden_units=10, \n",
    "                  output_shape=len(train_data.classes)).to(device)\n",
    "\n",
    "# Setup loss function and optimizer\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)\n",
    "\n",
    "# Start the timer\n",
    "from timeit import default_timer as timer \n",
    "start_time = timer()\n",
    "\n",
    "# Train model_0 \n",
    "model_0_results = train(model=model_0, \n",
    "                        train_dataloader=train_dataloader,\n",
    "                        test_dataloader=test_dataloader,\n",
    "                        optimizer=optimizer,\n",
    "                        loss_fn=loss_fn, \n",
    "                        epochs=NUM_EPOCHS,\n",
    "                        device=device)\n",
    "\n",
    "# End the timer and print out how long it took\n",
    "end_time = timer()\n",
    "print(f\"[INFO] Total training time: {end_time-start_time:.3f} seconds\")\n",
    "\n",
    "# Save the model\n",
    "save_model(model=model_0,\n",
    "           target_dir=\"models\",\n",
    "           model_name=\"05_going_modular_cell_mode_tinyvgg_model.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Md_CmXTvfkOz"
   },
   "source": [
    "### 6.1 Train, evaluate and save the model (script mode)\n",
    "\n",
    "Let's bring all of our modular files together in a single script, `train.py`.\n",
    "\n",
    "This will allow us to run all of the functions we've written with a single line of code on the command line:\n",
    "\n",
    "`python going_modular/train.py`\n",
    "\n",
    "Or if we're running it in a notebook:\n",
    "\n",
    "`!python going_modular/train.py`\n",
    "\n",
    "We'll go through the following steps:\n",
    "1. Import the various dependencies, namely `torch`, `os`, `torchvision.transforms` and all of the scripts from the `going_modular` directory: `data_setup`, `engine`, `model_builder` and `utils`.\n",
    "  * **Note:** Since `train.py` will be *inside* the `going_modular` directory, we can import the other modules via `import ...` rather than `from going_modular import ...`. This works because Python adds the directory of the script being run to the front of its module search path.\n",
    "2. Set up various hyperparameters such as batch size, number of epochs, learning rate and number of hidden units (these could be set in the future via [Python's `argparse`](https://docs.python.org/3/library/argparse.html)).\n",
    "3. Set up the training and test directories.\n",
    "4. Set up device-agnostic code.\n",
    "5. Create the necessary data transforms.\n",
    "6. Create the DataLoaders using `data_setup.py`.\n",
    "7. Create the model using `model_builder.py`.\n",
    "8. Set up the loss function and optimizer.\n",
    "9. Train the model using `engine.py`.\n",
    "10. Save the model using `utils.py`."
   ]
  },
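  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The import note in step 1 relies on Python placing the directory of the script being run at the front of its module search path. Here is a minimal sketch to check that behaviour. It uses a temporary directory and hypothetical file names (`demo_modular`, `helper.py`, `main.py`) as stand-ins, not the real `going_modular` files.\n",
    "\n",
    "```python\n",
    "import subprocess\n",
    "import sys\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "with tempfile.TemporaryDirectory() as project_root:\n",
    "    # Hypothetical stand-in for the going_modular/ directory\n",
    "    package_dir = Path(project_root) / 'demo_modular'\n",
    "    package_dir.mkdir()\n",
    "\n",
    "    # helper.py plays the role of a module such as data_setup.py\n",
    "    (package_dir / 'helper.py').write_text('MESSAGE = 42')\n",
    "\n",
    "    # main.py plays the role of train.py and uses a plain import\n",
    "    (package_dir / 'main.py').write_text('import helper; print(helper.MESSAGE)')\n",
    "\n",
    "    # Run the script from the parent directory, like python going_modular/train.py\n",
    "    result = subprocess.run(\n",
    "        [sys.executable, 'demo_modular/main.py'],\n",
    "        cwd=project_root,\n",
    "        capture_output=True,\n",
    "        text=True,\n",
    "        check=True,\n",
    "    )\n",
    "    output = result.stdout.strip()\n",
    "\n",
    "print(output)\n",
    "```\n",
    "\n",
    "The plain `import helper` succeeds even though the command is launched from the parent directory, which is the same situation as `python going_modular/train.py`."
   ]
  },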
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8XnG2l4zf2Ei",
    "outputId": "84f67ba0-ddaf-45a2-9a71-a37cf3d1ca88"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting going_modular/train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile going_modular/train.py\n",
    "\"\"\"\n",
    "Trains a PyTorch image classification model using device-agnostic code.\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "import data_setup, engine, model_builder, utils\n",
    "\n",
    "\n",
    "# Setup hyperparameters\n",
    "NUM_EPOCHS = 5\n",
    "BATCH_SIZE = 32\n",
    "HIDDEN_UNITS = 10\n",
    "LEARNING_RATE = 0.001\n",
    "\n",
    "# Setup directories\n",
    "train_dir = \"data/pizza_steak_sushi/train\"\n",
    "test_dir = \"data/pizza_steak_sushi/test\"\n",
    "\n",
    "# Setup target device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Create transforms\n",
    "data_transform = transforms.Compose([\n",
    "  transforms.Resize((64, 64)),\n",
    "  transforms.ToTensor()\n",
    "])\n",
    "\n",
    "# Create DataLoaders with help from data_setup.py\n",
    "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(\n",
    "    train_dir=train_dir,\n",
    "    test_dir=test_dir,\n",
    "    transform=data_transform,\n",
    "    batch_size=BATCH_SIZE\n",
    ")\n",
    "\n",
    "# Create model with help from model_builder.py\n",
    "model = model_builder.TinyVGG(\n",
    "    input_shape=3,\n",
    "    hidden_units=HIDDEN_UNITS,\n",
    "    output_shape=len(class_names)\n",
    ").to(device)\n",
    "\n",
    "# Set loss and optimizer\n",
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(),\n",
    "                             lr=LEARNING_RATE)\n",
    "\n",
    "# Start training with help from engine.py\n",
    "engine.train(model=model,\n",
    "             train_dataloader=train_dataloader,\n",
    "             test_dataloader=test_dataloader,\n",
    "             loss_fn=loss_fn,\n",
    "             optimizer=optimizer,\n",
    "             epochs=NUM_EPOCHS,\n",
    "             device=device)\n",
    "\n",
    "# Save the model with help from utils.py\n",
    "utils.save_model(model=model,\n",
    "                 target_dir=\"models\",\n",
    "                 model_name=\"05_going_modular_script_mode_tinyvgg_model.pth\")"
   ]
  },
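  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hyperparameters in `train.py` are hard-coded. As a rough sketch of how they might instead be read from the command line with `argparse`, the snippet below defines a small parser whose defaults mirror the constants above. The `parse_args` helper and the flag names (`--num_epochs`, `--batch_size`, `--hidden_units`, `--learning_rate`) are assumptions for illustration, not part of the original script.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "\n",
    "def parse_args(argv=None):\n",
    "    # Hypothetical helper: defaults match the constants used in train.py\n",
    "    parser = argparse.ArgumentParser(description='Train a TinyVGG model.')\n",
    "    parser.add_argument('--num_epochs', type=int, default=5, help='number of epochs to train for')\n",
    "    parser.add_argument('--batch_size', type=int, default=32, help='number of samples per batch')\n",
    "    parser.add_argument('--hidden_units', type=int, default=10, help='number of hidden units in TinyVGG')\n",
    "    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for the optimizer')\n",
    "    # argv=None makes argparse read sys.argv[1:]; an explicit list is handy in notebooks\n",
    "    return parser.parse_args(argv)\n",
    "\n",
    "\n",
    "# Override one value and keep the remaining defaults\n",
    "args = parse_args(['--num_epochs', '10'])\n",
    "print(args.num_epochs, args.batch_size, args.hidden_units, args.learning_rate)\n",
    "```\n",
    "\n",
    "Inside `train.py`, calling `parse_args()` with no arguments would read the flags from the command line, for example `python going_modular/train.py --num_epochs 10` (assuming the script were updated to use these values in place of the constants)."
   ]
  },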
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HHnPm-w7CjXT"
   },
   "source": [
    "Now our final directory structure looks like:\n",
    "```\n",
    "data/\n",
    "  pizza_steak_sushi/\n",
    "    train/\n",
    "      pizza/\n",
    "        train_image_01.jpeg\n",
    "        train_image_02.jpeg\n",
    "        ...\n",
    "      steak/\n",
    "      sushi/\n",
    "    test/\n",
    "      pizza/\n",
    "        test_image_01.jpeg\n",
    "        test_image_02.jpeg\n",
    "        ...\n",
    "      steak/\n",
    "      sushi/\n",
    "going_modular/\n",
    "  data_setup.py\n",
    "  engine.py\n",
    "  model_builder.py\n",
    "  train.py\n",
    "  utils.py\n",
    "models/\n",
    "  saved_model.pth\n",
    "```\n",
    "\n",
    "Now to put it all together!\n",
    "\n",
    "Let's run our `train.py` file from a notebook cell (the leading `!` passes the command to the shell):\n",
    "\n",
    "```\n",
    "!python going_modular/train.py\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6eVtanbSCjFj",
    "outputId": "9c90d0d8-9b44-4558-b74d-292729eca1e4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0%|                                                     | 0/5 [00:00<?, ?it/s]Epoch: 1 | train_loss: 1.1131 | train_acc: 0.2852 | test_loss: 1.1138 | test_acc: 0.2604\n",
      " 20%|█████████                                    | 1/5 [00:01<00:04,  1.06s/it]Epoch: 2 | train_loss: 1.0851 | train_acc: 0.4102 | test_loss: 1.1238 | test_acc: 0.1979\n",
      " 40%|██████████████████                           | 2/5 [00:01<00:02,  1.20it/s]Epoch: 3 | train_loss: 1.0837 | train_acc: 0.4141 | test_loss: 1.1459 | test_acc: 0.1979\n",
      " 60%|███████████████████████████                  | 3/5 [00:02<00:01,  1.33it/s]Epoch: 4 | train_loss: 1.1104 | train_acc: 0.2930 | test_loss: 1.1318 | test_acc: 0.1979\n",
      " 80%|████████████████████████████████████         | 4/5 [00:03<00:00,  1.40it/s]Epoch: 5 | train_loss: 1.0833 | train_acc: 0.2930 | test_loss: 1.0883 | test_acc: 0.3712\n",
      "100%|█████████████████████████████████████████████| 5/5 [00:03<00:00,  1.35it/s]\n",
      "[INFO] Saving model to: models/05_going_modular_script_mode_tinyvgg_model.pth\n"
     ]
    }
   ],
   "source": [
    "!python going_modular/train.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Woah!\n",
    "\n",
    "Look at that!\n",
    "\n",
    "We've just trained a model with a single line of code from the command line.\n",
    "\n",
    "We wrote a fair bit of code to do so; however, now that our code lives in `.py` files, we can import and reuse it as much as we like.\n",
    "\n",
    "For exercises and extra-curriculum for this section, refer to the [online book version of 05. PyTorch Going Modular](https://www.learnpytorch.io/05_pytorch_going_modular/#exercises)."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyOtemTdOJk4sI2quxZNvUuS",
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "05_pytorch_going_modular_script_mode.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "1e79615e722f456189e07874d13564d2": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "3402f79d3df643efbed25c72aac94642": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4bc100911d7d4303b6ab7efb59d3420c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4d780e75e5d24241b1f6c0203a8bd73e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f9b7fbae53e84883ac281343bb6a9da9",
      "placeholder": "​",
      "style": "IPY_MODEL_4bc100911d7d4303b6ab7efb59d3420c",
      "value": " 5/5 [00:11&lt;00:00,  2.23s/it]"
     }
    },
    "83cf71b79abd4fc5ae9daaac8f586841": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "84317b054c244ee4b1bed7b2f86660b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_a0a72aed7e2743f8a95c6af3b6e4a457",
       "IPY_MODEL_b517f855551f4c53969677161952a1ac",
       "IPY_MODEL_4d780e75e5d24241b1f6c0203a8bd73e"
      ],
      "layout": "IPY_MODEL_a0caf8f4760c4935b5b9be86566ee8e0"
     }
    },
    "a0a72aed7e2743f8a95c6af3b6e4a457": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_83cf71b79abd4fc5ae9daaac8f586841",
      "placeholder": "​",
      "style": "IPY_MODEL_3402f79d3df643efbed25c72aac94642",
      "value": "100%"
     }
    },
    "a0caf8f4760c4935b5b9be86566ee8e0": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b517f855551f4c53969677161952a1ac": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1e79615e722f456189e07874d13564d2",
      "max": 5,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_bac27d87b4d54b6cbaf85dd7b7ecdc64",
      "value": 5
     }
    },
    "bac27d87b4d54b6cbaf85dd7b7ecdc64": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "f9b7fbae53e84883ac281343bb6a9da9": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
